%reload_ext autoreload
%autoreload 2
from torch_snippets import *
from torch_snippets.markup2 import AD
reset_logger()
# 0. Load config
config = AD(
    pretrained_model_path="checkpoints/stable-diffusion-v1-5/",
    checkpoint_folder="outputs/train_stage_2_v1-2023-12-17T18-31-50/",
    checkpoint_file="outputs/train_stage_2_v1-2023-12-17T18-31-50/checkpoints/checkpoint-epoch-88.ckpt",
    clip_model_path="checkpoints/clip-vit-base-patch32/",
    noise_kwargs=AD(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        steps_offset=1,
        clip_sample=False,
    ),
    num_inference_steps=25,
    image=AD(
        source_image="/home/ubuntu/data/animate-anyone/TikTok_dataset/00001/images/0001.png",
        size=256,
    ),
    video=AD(
        video_path="/home/ubuntu/data/animate-anyone/TikTok_dataset/00003/00003_dwpose.mp4",
        max_length=24,
        offset=1,
    ),
)
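# AD is torch_snippets' attribute-dict container; nested values are read as attributes below (e.g. config.video.max_length).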
# 1. Load models
from pipelines.pipeline_stage_2 import AnimationAnyonePipeline, DDIMScheduler
from utils.load_models import load_models_stage_2
torch.set_grad_enabled(False) # no need for grad computations
%time models = load_models_stage_2(config)
aapipe = AnimationAnyonePipeline(
    vae=models.vae,
    text_encoder=models.text_encoder,
    tokenizer=models.tokenizer,
    unet=models.unet,
    referencenet=models.referencenet,
    scheduler=DDIMScheduler(**config.noise_kwargs),
)
_ = aapipe.to('cuda')
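# The pipeline bundles the Stable Diffusion VAE, CLIP text encoder/tokenizer, the denoising UNet, the ReferenceNet, and a DDIM scheduler built from config.noise_kwargs.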
# 2. Load video and image
from utils.videoreader import VideoReader
size = config.image.size
max_length = config.video.max_length
offset = config.video.offset
source_image = config.image.source_image
source_image = resize(read(source_image, 1), size)
video_path = config.video.video_path
control = VideoReader(video_path).read()
if control[0].shape[0] != size:
    control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
if max_length is not None:
    control = control[offset:(offset + max_length)]
control = np.array(control)
torch.Tensor(source_image), torch.Tensor(control)
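# Sanity check: source_image should now be a (size, size, 3) frame and control a (max_length, size, size, 3) array of pose frames.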
# 3. Create empty latents and timesteps
dtype = aapipe.unet.dtype
device = aapipe.unet.device
generator = torch.Generator(device=aapipe.unet.device)
generator.manual_seed(torch.initial_seed())
noisy_latents = aapipe.prepare_latents(
    batch_size=1,
    num_channels_latents=4,
    video_length=24,
    height=size,
    width=size,
    dtype=dtype,
    device=device,
    generator=generator,
    latents=None,
    clip_length=16,
)
extra_step_kwargs = aapipe.prepare_extra_step_kwargs(generator, eta=0.0)
aapipe.scheduler.set_timesteps(config.num_inference_steps, device=device)
timesteps = aapipe.scheduler.timesteps
noisy_latents, timesteps
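# noisy_latents starts as pure Gaussian noise at the latent resolution; timesteps holds the DDIM schedule of config.num_inference_steps values to denoise through.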
# 4. Set up the ReferenceNet attention writer/reader
from models.ReferenceNet_attention import ReferenceNetAttention
reference_control_writer = ReferenceNetAttention(aapipe.referencenet, do_classifier_free_guidance=False, mode='write', fusion_blocks='full', is_image=False)
reference_control_reader = ReferenceNetAttention(aapipe.unet, do_classifier_free_guidance=False, mode='read', fusion_blocks='full', is_image=False)
# 5. Make source-image VAE latents for the ReferenceNet and CLIP embeddings for the ReferenceNet and UNet
source_image_latents = aapipe.images2latents(source_image[None,:], dtype=dtype)
source_image_clip = models.clip_image_processor(images=Image.fromarray(source_image).convert('RGB'), return_tensors="pt").pixel_values.to(device=device)
source_image_clip_embeddings = models.clip_image_encoder(source_image_clip).unsqueeze(1).to(device=device, dtype=dtype)
source_image_latents, source_image_clip_embeddings
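# The CLIP image embedding stands in for text conditioning: it is passed as encoder_hidden_states to both the ReferenceNet and the denoising UNet below.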
# 6. Make pose latents for the UNet using the pose guider
#### pose condition ####
pixel_transforms = transforms.Compose([
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
pose_condition = torch.from_numpy(control.copy()).to(device=device, dtype=dtype).permute(0, 3, 1, 2) / 255.0
pose_condition = pixel_transforms(pose_condition)
pose_latents = models.poseguider(pose_condition)
pose_latents
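# pose_latents is at the latent resolution, so it can be added directly to the (scaled) noisy latents below.
# 7. Inject reference features and run a single denoising step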
reference_control_reader.clear()
reference_control_writer.clear()
aapipe.referencenet(source_image_latents, timesteps[0], source_image_clip_embeddings)
reference_control_reader.update(reference_control_writer)
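# In 'write' mode the ReferenceNet caches its self-attention features; update() copies them into the UNet's attention layers ('read' mode) so the video frames can attend to the reference image.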
t = timesteps[0]
unet_input_latents = aapipe.scheduler.scale_model_input(noisy_latents, t) + pose_latents
unet_input_latents
unet_noise_pred = aapipe.unet(
    unet_input_latents[None].permute(0, 2, 1, 3, 4),
    t,
    encoder_hidden_states=source_image_clip_embeddings,
).sample[0].permute(1, 0, 2, 3)
unet_noise_pred
unet_output_latents = aapipe.scheduler.step(unet_noise_pred, t, unet_input_latents, **extra_step_kwargs, return_dict=False)[0]
unet_output_latents
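# For reference, the single step above generalizes to the full denoising loop.
# A minimal sketch (re-running the UNet for every timestep); it assumes the
# reference features written above can be reused at each step, and it does not
# replace the one-step result decoded below.
full_latents = noisy_latents
for t_i in timesteps:
    model_input = aapipe.scheduler.scale_model_input(full_latents, t_i) + pose_latents
    noise_pred = aapipe.unet(
        model_input[None].permute(0, 2, 1, 3, 4),
        t_i,
        encoder_hidden_states=source_image_clip_embeddings,
    ).sample[0].permute(1, 0, 2, 3)
    full_latents = aapipe.scheduler.step(noise_pred, t_i, full_latents, **extra_step_kwargs, return_dict=False)[0]
# 8. Decode latents to frames and animate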
video = aapipe.decode_latents(unet_output_latents.detach(), rank=0)
torch.Tensor(video)
unet_output_latents.mean((1,2,3)).v
torch.Tensor(video).mean((1,2,3)).v
from utils.show_video import animate
animate(video)
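# Optionally persist the animation. A minimal sketch, assuming `video` is a
# float array in [0, 1] with frames along the first axis; imageio and the
# output path are assumptions, not part of this repo's utilities.
import imageio
frames = (np.asarray(video) * 255).clip(0, 255).astype(np.uint8)
imageio.mimwrite('outputs/animation_preview.mp4', list(frames), fps=8)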